package com.tang.kite.executor.defaults

import com.tang.kite.executor.Executor
import com.tang.kite.sql.statement.BatchSqlStatement
import com.tang.kite.sql.statement.SqlStatement
import com.tang.kite.transaction.Transaction
import com.tang.kite.utils.Statements
import com.tang.kite.utils.resultset.ResultSetHandlers
import org.slf4j.LoggerFactory
import java.sql.Connection
import java.sql.PreparedStatement
import java.sql.Statement

/**
 * Default [Executor] that runs SQL through JDBC [PreparedStatement]s on the connection supplied by [transaction].
 *
 * Every prepared statement created here is closed before the public method returns, including on failure.
 * When *executing* a statement fails, the error is logged, the connection is rolled back and a neutral
 * default (`0` or an empty list) is returned instead of throwing. Errors raised while binding parameters,
 * while reading generated keys, or by the rollback itself propagate to the caller.
 *
 * @author Tang
 */
class DefaultExecutor(private val transaction: Transaction) : Executor {

    private val logger = LoggerFactory.getLogger(DefaultExecutor::class.java)

    /** Returns the connection of the underlying [Transaction]. */
    override fun getConnection(): Connection = transaction.getConnection()

    /**
     * Runs a count query and returns the value produced by [ResultSetHandlers.getCount].
     *
     * @param statement SQL and values to bind.
     * @param type not used by this implementation.
     * @return the count, or `0` when executing the query fails (the connection is rolled back).
     */
    override fun <T> count(statement: SqlStatement, type: Class<T>): Long {
        val connection = getConnection()
        val preparedStatement = prepare(connection, statement.sql) { statement.setValues(it) }
        return executeTemplate(
            preparedStatement = preparedStatement,
            connection = connection,
            statementType = "count",
            // Close the ResultSet explicitly instead of relying on the statement close to do it.
            runAction = { ps -> ps.executeQuery().use { resultSet -> ResultSetHandlers.getCount(resultSet) } },
            defaultValue = 0L
        )
    }

    /**
     * Runs a select query and maps every row to [type] via [ResultSetHandlers.getList].
     *
     * @param statement SQL and values to bind.
     * @param type target element class.
     * @return the mapped rows, or an empty list when executing the query fails (the connection is rolled back).
     */
    override fun <T> query(statement: SqlStatement, type: Class<T>): List<T> {
        val connection = getConnection()
        val preparedStatement = prepare(connection, statement.sql) { statement.setValues(it) }
        return executeTemplate(
            preparedStatement = preparedStatement,
            connection = connection,
            statementType = "query",
            runAction = { ps -> ps.executeQuery().use { resultSet -> ResultSetHandlers.getList(resultSet, type) } },
            defaultValue = mutableListOf()
        )
    }

    /**
     * Executes a single insert/update/delete statement.
     *
     * If [ResultSetHandlers.hasGeneratedKey] reports a generated key for [parameter], the statement is prepared
     * with [Statement.RETURN_GENERATED_KEYS]. After a successful execution, [ResultSetHandlers.setGeneratedKey]
     * is called with [parameter].
     *
     * @return the affected row count, or `0` when execution fails (the connection is rolled back).
     */
    override fun update(statement: SqlStatement, parameter: Any): Int {
        val connection = getConnection()
        val autoGeneratedKeys = getAutoGeneratedKeys(statement.sql, parameter)
        val preparedStatement = prepare(connection, statement.sql, autoGeneratedKeys) { statement.setValues(it) }
        return executeTemplate(
            preparedStatement = preparedStatement,
            connection = connection,
            statementType = "update",
            runAction = { it.executeUpdate() },
            alsoAction = { ResultSetHandlers.setGeneratedKey(statement.sql, it, parameter) },
            defaultValue = 0
        )
    }

    /**
     * Executes one SQL statement as a JDBC batch, adding one batch entry per element of
     * [BatchSqlStatement.parameters].
     *
     * @param parameters the objects passed to the generated-key handling. May be empty, in which case no
     *   generated keys are requested or assigned.
     * @return the sum of the non-negative per-entry update counts, or `0` when execution fails
     *   (the connection is rolled back).
     */
    override fun update(batchSqlStatement: BatchSqlStatement, parameters: List<Any>): Int {
        val connection = getConnection()
        // firstOrNull: an empty list must not crash with NoSuchElementException before anything runs.
        val autoGeneratedKeys = parameters.firstOrNull()
            ?.let { getAutoGeneratedKeys(batchSqlStatement.sql, it) }
            ?: Statement.NO_GENERATED_KEYS
        val preparedStatement = prepare(connection, batchSqlStatement.sql, autoGeneratedKeys) { ps ->
            for (values in batchSqlStatement.parameters) {
                Statements.setValues(ps, values)
                ps.addBatch()
            }
        }
        return executeTemplate(
            preparedStatement = preparedStatement,
            connection = connection,
            statementType = "batch update",
            // Drivers may report SUCCESS_NO_INFO (-2) or EXECUTE_FAILED (-3) per entry; summing those
            // would yield a negative row count, so only known non-negative counts are added.
            runAction = { ps -> ps.executeBatch().sumOf { rows -> rows.coerceAtLeast(0) } },
            alsoAction = { ps ->
                if (parameters.isNotEmpty()) {
                    ResultSetHandlers.setGeneratedKey(batchSqlStatement.sql, ps, parameters)
                }
            },
            defaultValue = 0
        )
    }

    /**
     * Executes [statements] one after another on the same connection, pairing `statements[i]` with
     * `parameters[i]` for generated-key handling.
     *
     * Execution stops at the first failing statement. By then the connection has already been rolled back,
     * so running the remaining statements would apply only the tail of the work.
     *
     * @throws IllegalArgumentException if [parameters] has fewer elements than [statements].
     * @return the total affected row count, or `0` when any statement fails.
     */
    override fun update(statements: List<SqlStatement>, parameters: List<Any>): Int {
        require(parameters.size >= statements.size) {
            "Expected at least ${statements.size} parameters for ${statements.size} statements but got ${parameters.size}"
        }
        val connection = getConnection()
        var totalUpdatedRows = 0
        for ((index, statement) in statements.withIndex()) {
            val parameter = parameters[index]
            val autoGeneratedKeys = getAutoGeneratedKeys(statement.sql, parameter)
            val preparedStatement = prepare(connection, statement.sql, autoGeneratedKeys) { statement.setValues(it) }
            val updatedRows = execute(
                preparedStatement = preparedStatement,
                connection = connection,
                statementType = "batch update",
                runAction = { it.executeUpdate() },
                alsoAction = { ResultSetHandlers.setGeneratedKey(statement.sql, it, parameter) }
            ).getOrElse { return 0 } // already logged and rolled back by execute()
            totalUpdatedRows += updatedRows
        }
        return totalUpdatedRows
    }

    /** Commits the underlying [Transaction]. */
    override fun commit() {
        transaction.commit()
    }

    /** Rolls back the underlying [Transaction]. */
    override fun rollback() {
        transaction.rollback()
    }

    /** Closes the underlying [Transaction]. */
    override fun close() {
        transaction.close()
    }

    /**
     * Prepares [sql] on [connection] and binds values through [bind].
     *
     * If [bind] throws, the freshly created statement is closed before the exception is rethrown, so a
     * binding error cannot leak it.
     *
     * @param autoGeneratedKeys `null` to use the plain `prepareStatement(sql)` overload, otherwise the flag
     *   passed to `prepareStatement(sql, autoGeneratedKeys)`.
     */
    private fun prepare(
        connection: Connection,
        sql: String,
        autoGeneratedKeys: Int? = null,
        bind: (PreparedStatement) -> Unit
    ): PreparedStatement {
        val preparedStatement = if (autoGeneratedKeys == null) {
            connection.prepareStatement(sql)
        } else {
            connection.prepareStatement(sql, autoGeneratedKeys)
        }
        try {
            bind(preparedStatement)
        } catch (error: Throwable) {
            runCatching { preparedStatement.close() }.exceptionOrNull()?.let { error.addSuppressed(it) }
            throw error
        }
        return preparedStatement
    }

    /**
     * Runs [runAction] through [execute] and maps an execution failure to [defaultValue].
     */
    private fun <R> executeTemplate(
        preparedStatement: PreparedStatement,
        connection: Connection,
        statementType: String,
        runAction: (PreparedStatement) -> R,
        alsoAction: (PreparedStatement) -> Unit = {},
        defaultValue: R
    ): R = execute(preparedStatement, connection, statementType, runAction, alsoAction).getOrDefault(defaultValue)

    /**
     * Executes [runAction] on [preparedStatement] and always closes the statement afterwards.
     *
     * - On success, [alsoAction] runs (before the close) and the result is returned as a success.
     * - On failure, the error is logged, [connection] is rolled back, and the failure is returned.
     *
     * Exceptions thrown by [alsoAction] or by the rollback itself propagate. The statement is still closed
     * in that case.
     */
    private fun <R> execute(
        preparedStatement: PreparedStatement,
        connection: Connection,
        statementType: String,
        runAction: (PreparedStatement) -> R,
        alsoAction: (PreparedStatement) -> Unit = {}
    ): Result<R> = preparedStatement.use { ps ->
        try {
            runCatching { runAction(ps) }
                .onSuccess { alsoAction(ps) }
                .onFailure { rollbackAfterFailure(connection, statementType, it) }
        } finally {
            logger.debug("Closing {} prepared statement [{}]", statementType, ps)
        }
    }

    /**
     * Logs [cause] and rolls back [connection].
     *
     * If the rollback itself fails, the rollback error is rethrown with [cause] attached as suppressed, so the
     * original failure is not lost.
     */
    private fun rollbackAfterFailure(connection: Connection, statementType: String, cause: Throwable) {
        logger.error("Failed to execute {} statement, rolling back", statementType, cause)
        try {
            connection.rollback()
        } catch (rollbackError: Throwable) {
            rollbackError.addSuppressed(cause)
            throw rollbackError
        }
    }

    /**
     * Chooses the JDBC generated-keys flag for [sql]: [Statement.RETURN_GENERATED_KEYS] when
     * [ResultSetHandlers.hasGeneratedKey] reports a key for [parameter], otherwise [Statement.NO_GENERATED_KEYS].
     */
    private fun getAutoGeneratedKeys(sql: String, parameter: Any): Int =
        if (ResultSetHandlers.hasGeneratedKey(sql, parameter)) {
            Statement.RETURN_GENERATED_KEYS
        } else {
            Statement.NO_GENERATED_KEYS
        }

}
